Initialize Qwen3.5 mutable buffers during export#17801
Initialize Qwen3.5 mutable buffers during export#17801Phineas1500 wants to merge 2 commits intopytorch:mainfrom
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/17801
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 4 Unrelated FailuresAs of commit 4826903 with merge base 9d413ac ( NEW FAILURES - The following jobs have failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@pytorchbot label "release notes: none" |
There was a problem hiding this comment.
Pull request overview
This PR adds Qwen3.5 support to the Llama export pipeline with deterministic initialization of Qwen3.5’s internal mutable buffers (KV cache + DeltaNet recurrent/conv state) during export, and introduces the Qwen3.5 attention implementations/configs needed to run/export the hybrid layer layout.
Changes:
- Add Qwen3.5 model types/configs and HF weight conversion utilities for ExecuTorch “meta” format.
- Implement Qwen3.5 hybrid attention blocks (full attention + Gated DeltaNet linear attention) and wire hybrid layer construction into the Llama transformer.
- Factor/export additional mutable-buffer initialization pass selection (torchtune + Qwen3.5) into a shared helper and add unit tests for pass selection and attention state reset.
Reviewed changes
Copilot reviewed 21 out of 21 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| extension/llm/export/config/llm_config.py | Adds Qwen3.5 model types to the export config enum. |
| examples/models/qwen3_5/tests/test_convert_weights.py | Unit test for Qwen3.5 HF→meta key mapping. |
| examples/models/qwen3_5/tests/init.py | Package marker/license header for Qwen3.5 tests. |
| examples/models/qwen3_5/convert_weights.py | Implements Qwen3.5 checkpoint loading and key conversion (incl. legacy packed tensor splitting). |
| examples/models/qwen3_5/config/qwen3_5_xnnpack_fp32.yaml | Adds an fp32/static-shape XNNPACK export config for Qwen3.5. |
| examples/models/qwen3_5/config/4b_config.json | Adds model args for Qwen3.5 4B (hybrid layer_types etc.). |
| examples/models/qwen3_5/config/2b_config.json | Adds model args for Qwen3.5 2B. |
| examples/models/qwen3_5/config/0_8b_config.json | Adds model args for Qwen3.5 0.8B. |
| examples/models/qwen3_5/init.py | Adds a Qwen3.5 model entrypoint (lazy subclass of Llama2Model) and exports convert_weights. |
| examples/models/qwen3_5/README.md | Documents export/run instructions for Qwen3.5 models. |
| examples/models/qwen3_5/BUCK | Adds Buck target for the Qwen3.5 Python library + deps. |
| examples/models/llama/tests/test_qwen3_5_attention.py | Adds tests for Qwen3.5 full-attn shape and DeltaNet state reset behavior. |
| examples/models/llama/tests/test_export_llama_lib.py | Adds tests covering export-pass selection for Qwen3.5/torchtune/llama3. |
| examples/models/llama/tests/BUCK | Registers the new Qwen3.5 attention unittest target. |
| examples/models/llama/norm.py | Extends RMSNorm to support Qwen3.5 “(1 + weight)” scaling. |
| examples/models/llama/model_args.py | Adds Qwen3.5 linear-attention dims + RMSNorm scaling flag to ModelArgs with defaults. |
| examples/models/llama/llama_transformer.py | Wires RMSNorm scaling flag and constructs DeltaNet layers when layer_types specify linear_attention. |
| examples/models/llama/export_llama_lib.py | Adds Qwen3.5 model ids, hooks Qwen3.5 weight conversion, and factors mutable-buffer init pass selection into helper. |
| examples/models/llama/attention.py | Adds Qwen3.5 full attention and Gated DeltaNet attention implementations (+ required buffers). |
| examples/models/llama/init.py | Switches llama package export to lazy import pattern for Llama2Model. |
| examples/models/BUCK | Adds the Qwen3.5 model package to the umbrella models BUCK target. |
Comments suppressed due to low confidence (1)
examples/models/llama/norm.py:60
- RMSNorm currently returns
output * self.weightwhenadd_unit_offsetis False. Sinceoutputis cast back totype_as(x)butself.weightstays fp32, this multiplication will promote the result to fp32 for fp16/bf16 inputs. The newadd_unit_offsetbranch explicitly casts the weight totype_as(x), so the dtype behavior is now inconsistent between the two paths. Consider castingself.weighttotype_as(x)(or otherwise ensuring the output dtype matches the input) in the non-offset path as well.
return output * (1.0 + self.weight.float()).type_as(x)
return output * self.weight
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| raise ValueError( | ||
| f"Invalid packed in_proj_qkvz shape for {key}: {tuple(value.shape)}" | ||
| ) | ||
| qkv, z = torch.split(value, [conv_dim, value_dim], dim=0) |
There was a problem hiding this comment.
key_dim is computed when splitting legacy packed in_proj_qkvz.weight but is never used afterward. Please remove it or use it for an explicit shape validation to avoid dead code.
|
Validated export and runtime with the XNNPACK recipe. Set max_seq_len and max_context_len to 128, generated the .pte, and ran executorch.examples.models.llama.runner.native with a multi-token prompt. The model currently uses static-shape export in this path, but I added sequential token prefill. |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated 5 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| try: | ||
| return self.forward( | ||
| tokens=torch.tensor( | ||
| [prompt_tokens], dtype=torch.long, device=self.device | ||
| ), | ||
| input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), | ||
| ) | ||
| except RuntimeError: | ||
| # Some exported models use a static single-token shape for kv-cache mode. | ||
| # Fall back to sequential token prefill so multi-token prompts still work. | ||
| if self.enable_dynamic_shape or len(prompt_tokens) <= 1: | ||
| raise | ||
|
|
||
| return self._sequential_kv_prefill(prompt_tokens, pos_base) | ||
|
|
There was a problem hiding this comment.
In _prefill_with_kv_cache, the early return for not self.enable_dynamic_shape and len(prompt_tokens) > 1 means the subsequent try/except RuntimeError never executes for the “static single-token shape” case described below. As written, the fallback logic is effectively dead code for the static-shape scenario; consider simplifying to a single path (either always sequential when static, or always try batched then fall back).
| try: | |
| return self.forward( | |
| tokens=torch.tensor( | |
| [prompt_tokens], dtype=torch.long, device=self.device | |
| ), | |
| input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), | |
| ) | |
| except RuntimeError: | |
| # Some exported models use a static single-token shape for kv-cache mode. | |
| # Fall back to sequential token prefill so multi-token prompts still work. | |
| if self.enable_dynamic_shape or len(prompt_tokens) <= 1: | |
| raise | |
| return self._sequential_kv_prefill(prompt_tokens, pos_base) | |
| return self.forward( | |
| tokens=torch.tensor( | |
| [prompt_tokens], dtype=torch.long, device=self.device | |
| ), | |
| input_pos=torch.tensor([pos_base], dtype=torch.long, device=self.device), | |
| ) |
| if self.enable_dynamic_shape or len(prompt_tokens) <= 1: | ||
| raise | ||
|
|
There was a problem hiding this comment.
The except RuntimeError fallback is currently unreachable when self.enable_dynamic_shape is True (the default): the handler re-raises whenever self.enable_dynamic_shape is true, so a static exported model that throws on batched prefill (and lacks the enable_dynamic_shape metadata method) will fail instead of falling back to sequential prefill. Consider falling back based on len(prompt_tokens) > 1 (and/or the specific error) rather than the enable_dynamic_shape flag, or updating the flag when the batched call fails.
| if self.enable_dynamic_shape or len(prompt_tokens) <= 1: | |
| raise | |
| # | |
| # If the batched prefill fails for a multi-token prompt, disable dynamic | |
| # shape support and retry using sequential prefill. For single-token | |
| # prompts, propagate the error. | |
| if len(prompt_tokens) <= 1: | |
| raise | |
| # Avoid retrying batched dynamic-shape prefill after a failure. | |
| self.enable_dynamic_shape = False |
examples/models/llama/norm.py
Outdated
| output = self._norm(x.float()).type_as(x) | ||
| if self.add_unit_offset: | ||
| return output * (1.0 + self.weight.float()).type_as(x) | ||
| return output * self.weight |
There was a problem hiding this comment.
RMSNorm.forward returns output * self.weight when add_unit_offset is false, which will promote dtypes (e.g., fp16 input → fp32 output) because self.weight is fp32. In the new add_unit_offset branch you explicitly cast the scale to type_as(x), so the output dtype now depends on the flag. Consider casting self.weight (or the final product) to type_as(x) in both branches to keep output dtype consistent with the input.
| return output * self.weight | |
| return output * self.weight.type_as(x) |
| try: | ||
| self.enable_dynamic_shape = bool( | ||
| self.model.run_method("enable_dynamic_shape")[0] | ||
| ) | ||
| except Exception: | ||
| # Keep default behavior when metadata method is unavailable. | ||
| pass |
There was a problem hiding this comment.
Catching a bare Exception around run_method("enable_dynamic_shape") can also hide real runtime issues (e.g., model load/ABI problems) and silently keep enable_dynamic_shape=True. It would be safer to catch the specific “method missing”/runtime exceptions raised by run_method (and optionally log at debug level) so unexpected failures don’t get swallowed.
| try: | ||
| new_key = get_mapped_key(normalized_key, _QWEN_3_5_TO_META) | ||
| except Exception: | ||
| # Ignore non-text weights and training-only extras (e.g., MTP). | ||
| if ( | ||
| key.startswith("mtp.") | ||
| or key.startswith("model.visual.") | ||
| or ".vision_" in key | ||
| or key.startswith("visual.") | ||
| ): | ||
| continue | ||
| # Ignore unsupported keys that are not required by the export model. | ||
| continue |
There was a problem hiding this comment.
The except Exception: ... continue around get_mapped_key will silently drop any unexpected keys (including genuinely required text weights if the mapping is incomplete or the checkpoint format changes). This makes conversion failures hard to detect. Consider only ignoring a well-defined allowlist of optional prefixes (vision/MTP/etc.) and re-raising for other model.* keys, or at least logging the first few unmapped keys at warning level.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| ): | ||
| if _should_ignore_unmapped_key(key, normalized_key): | ||
| continue | ||
| continue |
There was a problem hiding this comment.
In the non-language-model key branch, _should_ignore_unmapped_key(...) currently has no effect because the code continues unconditionally whether the key is ignored or not. This makes it easy to silently drop unexpected checkpoint keys. Consider either (a) raising for non-ignored keys here, or (b) removing the ignore-check entirely if the intent is to ignore all non-text keys.
| continue | |
| raise ValueError( | |
| "Unexpected non-language-model checkpoint key not mapped for " | |
| f"Qwen3.5 export: {key}" | |
| ) |
| "model.layers.{}.linear_attn.in_proj_b.weight": "layers.{}.attention.in_proj_b.weight", | ||
| "model.layers.{}.linear_attn.in_proj_a.weight": "layers.{}.attention.in_proj_a.weight", | ||
| "model.layers.{}.linear_attn.conv1d.weight": "layers.{}.attention.conv1d.weight", | ||
| "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", |
There was a problem hiding this comment.
The converter maps model.layers.*.linear_attn.conv1d.bias to layers.*.attention.conv1d.bias, but the corresponding ExecuTorch module (AttentionGatedDeltaNet) constructs conv1d with bias=False (no conv1d.bias parameter). If HF checkpoints can include this bias, it would be better to ignore/drop it during conversion (or flip the module to bias=True) to avoid carrying an unused tensor and relying on strict=False loads.
| "model.layers.{}.linear_attn.conv1d.bias": "layers.{}.attention.conv1d.bias", |
| # Legacy packed tensors (older checkpoints): | ||
| # in_proj_qkvz -> split into in_proj_qkv and in_proj_z | ||
| # in_proj_ba -> split into in_proj_b and in_proj_a | ||
| if normalized_key.endswith(".linear_attn.in_proj_qkvz.weight"): | ||
| pending_qkvz[normalized_key] = value | ||
| continue | ||
| if normalized_key.endswith(".linear_attn.in_proj_ba.weight"): | ||
| pending_ba[normalized_key] = value | ||
| continue |
There was a problem hiding this comment.
The legacy packed-key handling (in_proj_qkvz / in_proj_ba) introduces non-trivial splitting logic that isn't covered by the new unit tests. Add a test case that includes a packed key plus the required out_proj.weight so the split shapes/keys (and error paths) are exercised.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def _get_additional_export_passes(model_class: str) -> List[InitializedMutableBufferPass]: | ||
| patterns = [] | ||
|
|
||
| if model_class in TORCHTUNE_DEFINED_MODELS: | ||
| patterns.append("kv_cache_pos") | ||
|
|
||
| # Qwen3.5 uses internal mutable buffers for both the hybrid KV path and | ||
| # DeltaNet recurrent/conv states. | ||
| if model_class.startswith("qwen3_5"): | ||
| patterns.extend( | ||
| [ | ||
| "k_cache", | ||
| "v_cache", | ||
| "conv_state", | ||
| "recurrent_state", | ||
| ] | ||
| ) | ||
|
|
||
| return [InitializedMutableBufferPass(patterns)] if patterns else [] |
There was a problem hiding this comment.
InitializedMutableBufferPass causes matched mutated buffers to be serialized with their initial values. Initializing large buffers like k_cache/v_cache will therefore increase the exported .pte size (and potentially load time / memory pressure) by the full KV-cache tensor sizes. If this is expected, consider adding a config flag (or model-class-specific opt-out) so callers can choose determinism vs artifact size, and document the expected size impact for Qwen3.5 exports.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 25 out of 25 changed files in this pull request and generated no new comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
95a76b7 to
4826903
Compare
|
@lucylq rebased |
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| if model_class.startswith("qwen3_5"): | ||
| patterns.extend( | ||
| [ | ||
| "k_cache", | ||
| "v_cache", |
There was a problem hiding this comment.
Initializing KV cache buffers via InitializedMutableBufferPass will cause their full tensor contents to be serialized into the .pte (the emitter treats et_init_buffer+mutable_buffer as const). For k_cache/v_cache this can be extremely large (per-layer [B, H, S, D]) and may blow up export size and load time. Consider avoiding initializing the full KV caches at export (e.g., only init the small state buffers like conv_state/recurrent_state, or add a runtime/cache-reset path that deterministically zeros these buffers without serializing them).
| if model_class.startswith("qwen3_5"): | |
| patterns.extend( | |
| [ | |
| "k_cache", | |
| "v_cache", | |
| # Avoid initializing large KV cache buffers (k_cache/v_cache) here, since | |
| # InitializedMutableBufferPass would serialize their full contents into | |
| # the exported artifact, significantly increasing size and load time. | |
| if model_class.startswith("qwen3_5"): | |
| patterns.extend( | |
| [ |
There was a problem hiding this comment.
This is a good point. @Phineas1500 does qwen3.5 require initial state for the kv-cache, conv_state and recurrent_state?
The InitializedMutableBufferPass is only required for mutable buffers with initial state.
There was a problem hiding this comment.
Seems like ~5mb size increase from including initial state. Not too sure why - was expecting a bit more.
-rw-r--r-- 1 lfq users 4032780800 Mar 6 14:29 qwen3_5_0_8b_fp32_no_init.pte
-rw-r--r-- 1 lfq users 4038122240 Mar 5 11:05 qwen3_5_0_8b_fp32.pte
Output is the same with temp=0
(executorch) [lfq@devvm311.ldc0 /data/users/lfq/executorch (qwen3_5_phase2)]$ python -m executorch.examples
.models.llama.runner.native --model qwen3_5_0_8b --pte qwen3_5_0_8b_fp32_no_init.pte --tokenizer ~/.cache/h
uggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9cf38778875588b17/tokenizer.json
--tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3.5-0.8B/snapshots/2fc06364715b967f1860aea9c
f38778875588b17/tokenizer_config.json --params examples/models/qwen3_5/config/0_8b_config.json --prompt "<|
im_start|>user\nHello, what's 15% of 80?<|im_end|>\n<|im_start|>assistant\n" --max_len 128 -kv --temperatur
e 0
I tokenizers:regex.cpp:27] Registering override fallback regex
Warning - given vocab_size in params is unequal to tokenizer vocab size.
[cpuinfo_utils.cpp:71] Reading file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:87] Failed to open midr file /sys/devices/soc0/image_version
[cpuinfo_utils.cpp:100] Reading file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:109] Failed to open midr file /sys/devices/system/cpu/cpu0/regs/identification/midr_el1
[cpuinfo_utils.cpp:125] CPU info and manual query on # of cpus dont match.
<think>
</think>
To find 15% of 80, you can multiply 80 by 0.15:
$$80 \times 0.15 = 12$$
So, **15% of 80 is 12**.
Prefill time: 15.471495151519775
Generation tok/s: 2.097784149345492
Response: [248068, 271, 248069, 271, 1206, 1423, 220, 16, 20, 4, 314, 220, 23, 15, 11, 488, 628, 29283, 220, 23, 15, 539, 220, 15, 13, 16, 20, 25, 271, 13682, 23, 15, 1088, 14695, 220, 15, 13, 16, 20, 283, 220, 16, 17, 13682, 271, 4272, 11, 2972, 16, 20, 4, 314, 220, 23, 15, 369, 220, 16, 17, 159034, 248046]
Seems like the state is already zeroed here?
https://github.com/pytorch/executorch/blob/main/examples/models/llama/attention.py#L720
| "k_cache", | ||
| "v_cache", |
There was a problem hiding this comment.
InitializedMutableBufferPass matches patterns by substring. Using "k_cache"/"v_cache" here will also match other buffer names like "k_cache_scales", "k_cache_zero_points", or "past_k_caches_*" if present in the exported graph, potentially initializing/serializing more (large) buffers than intended. If you only mean the primary caches, consider narrowing the patterns to something less collision-prone (e.g., include a delimiter or full buffer name) or splitting by known FQN fragments.
| "k_cache", | |
| "v_cache", | |
| ".k_cache", | |
| ".v_cache", |
lucylq
left a comment
There was a problem hiding this comment.
I am not sure if this PR is necessary - seems like the recurrent state is initialized in the code _maybe_reset_state. I'm not sure kv-cache needs to be initialized. We get some binary size savings by not storing initial-value 0s in the .pte file but it's not as much as I expected only 5mb.
Let me know what you think @Phineas1500 . I added a longer comment in the code.
@lucylq I agree with you, the operations are unnecessary. I think this PR can be closed, since nothing else in it is important. I think I'm going to make another PR addressing this concern #17800 (comment) It'll improve performance if I replace the loop in attention.py that runs once per token (_recurrent_gated_delta_rule) with a custom op. Let me know if you think that sounds good or not. |
@Phineas1500 sounds good, this PR can be closed then. Yes it would be great if you could optimize the recurrence! Do you also have time to work on quantization - perhaps 8da4w with xnnpack? (if not, I'll take a look). |
@lucylq sure, happy to look at both. should i create a separate PR for each? |
yes please! Thanks so much, appreciate it 🙏 |
Summary
Why
Qwen3.5 uses internal mutable state (KV + DeltaNet recurrent/conv buffers). Initializing these buffers at export time avoids uninitialized mutable-buffer state and makes startup behavior deterministic.
Test Plan
Stacking